import torch
import numpy as np


class CoupleData:
    """Synthetic data generator driven by one shared latent variable ``U``.

    Structural equations implemented here (every ``eps`` is an independent
    ``Normal(0, 1)`` draw)::

        U  ~ Uniform(-1, 1)
        A1 = 0.5 * (U + 3)                 + eps
        A2 = 0.5 * (tanh(U) + 3)           + eps
        A3 = 0.5 * (sin(pi / 8 * U) + 3)   + eps
        A4 = 0.5 * (1 / (1 + exp(U)) + 3)  + eps
        A5 = 0.5 * (cos(pi / 8 * U) + 3)   + eps
        Y1 = 2 sin(1.4 A1 + 2 A3^2) + 0.5 (A2 + A4^2 + A5) + A3^3 + U + eps
        Y2 = -2 cos(1.8 A2) + 1.5 A4^2 + U + eps
        Y3 = 0.7 A3^2 + 1.2 A4 + U + eps
        Y4 = 0.2 exp(-A1 + 1) + 1.4 A5^2 + U + eps

    All sampling goes through NumPy's legacy global RNG (``np.random``), so
    methods that seed it change global state for the whole process.
    """

    def __init__(self, seeds=43, size=1000) -> None:
        """Store the sampling configuration.

        Args:
            seeds: Seed handed to ``np.random.seed`` by ``generatate_coup``.
            size: Number of rows drawn by ``generatate_coup``.
        """
        self.size = size
        self.seeds = seeds

    # ------------------------------------------------------------------
    # Noise-free parts of the structural equations.  They are shared by the
    # sampler and by the effect curves so the two can never drift apart.
    # Every argument may be a NumPy array or a scalar.
    # ------------------------------------------------------------------
    @staticmethod
    def _mean_a1(u):
        """Return the noise-free part of A1 for latent value(s) ``u``."""
        return 0.5 * (u + 3)

    @staticmethod
    def _mean_a2(u):
        """Return the noise-free part of A2 for latent value(s) ``u``."""
        return 0.5 * (np.tanh(u) + 3)

    @staticmethod
    def _mean_a3(u):
        """Return the noise-free part of A3 for latent value(s) ``u``."""
        return 0.5 * (np.sin(np.pi / 8 * u) + 3)

    @staticmethod
    def _mean_a4(u):
        """Return the noise-free part of A4 for latent value(s) ``u``."""
        return 0.5 * (1.0 / (1.0 + np.exp(u)) + 3)

    @staticmethod
    def _mean_a5(u):
        """Return the noise-free part of A5 for latent value(s) ``u``."""
        return 0.5 * (np.cos(np.pi / 8 * u) + 3)

    @staticmethod
    def _mean_y1(a1, a2, a3, a4, a5, u):
        """Return the noise-free part of Y1."""
        return (2 * np.sin(1.4 * a1 + 2 * a3 ** 2)
                + 0.5 * (a2 + a4 ** 2 + a5) + a3 ** 3 + u)

    @staticmethod
    def _mean_y2(a2, a4, u):
        """Return the noise-free part of Y2."""
        return -2 * np.cos(1.8 * a2) + 1.5 * a4 ** 2 + u

    @staticmethod
    def _mean_y3(a3, a4, u):
        """Return the noise-free part of Y3."""
        return 0.7 * a3 ** 2 + 1.2 * a4 + u

    @staticmethod
    def _mean_y4(a1, a5, u):
        """Return the noise-free part of Y4."""
        return 0.2 * np.exp(-a1 + 1) + 1.4 * a5 ** 2 + u

    @staticmethod
    def _noise(size):
        """Draw ``size`` standard-normal values from the global NumPy RNG."""
        return np.random.normal(0, 1, size=size)

    @staticmethod
    def _seed_if_given(seed) -> None:
        """Seed the global NumPy RNG unless ``seed`` is ``None``."""
        if seed is not None:
            np.random.seed(seed)

    @staticmethod
    def _simulate(size, seed, totensor):
        """Seed the global RNG and draw one data matrix.

        The order of the random draws (U, A1..A5, Y1..Y4) is fixed; changing
        it would change the data produced for a given seed.
        """
        cls = CoupleData
        np.random.seed(seed)
        u = np.random.uniform(-1, 1, size=size)
        a1 = cls._mean_a1(u) + cls._noise(size)
        a2 = cls._mean_a2(u) + cls._noise(size)
        a3 = cls._mean_a3(u) + cls._noise(size)
        a4 = cls._mean_a4(u) + cls._noise(size)
        a5 = cls._mean_a5(u) + cls._noise(size)

        y1 = cls._mean_y1(a1, a2, a3, a4, a5, u) + cls._noise(size)
        y2 = cls._mean_y2(a2, a4, u) + cls._noise(size)
        y3 = cls._mean_y3(a3, a4, u) + cls._noise(size)
        y4 = cls._mean_y4(a1, a5, u) + cls._noise(size)

        # Column order: A1..A5 followed by Y1..Y4; U itself is not returned.
        data = np.c_[a1, a2, a3, a4, a5, y1, y2, y3, y4]

        if totensor:
            data = torch.tensor(data, dtype=torch.float32)

        return data

    def generatate_coup(self, totensor: bool = False) -> "np.ndarray | torch.Tensor":
        """Draw ``self.size`` rows using ``self.seeds`` as the RNG seed.

        Reseeds the global NumPy RNG as a side effect.

        Args:
            totensor: When true, return a ``torch.float32`` tensor instead of
                a NumPy array.

        Returns:
            Matrix with columns ``A1..A5, Y1..Y4`` (9 columns).
        """
        return CoupleData._simulate(self.size, self.seeds, totensor)

    def generate_coup(self, totensor: bool = False) -> "np.ndarray | torch.Tensor":
        """Correctly spelled alias of ``generatate_coup``."""
        return self.generatate_coup(totensor)

    @staticmethod
    def generate_test(size, seed=43, totensor=False) -> "np.ndarray | torch.Tensor":
        """Draw ``size`` rows with an explicit seed, without an instance.

        Reseeds the global NumPy RNG as a side effect.

        Args:
            size: Number of rows to draw.
            seed: Value handed to ``np.random.seed``.
            totensor: When true, return a ``torch.float32`` tensor instead of
                a NumPy array.

        Returns:
            Matrix with columns ``A1..A5, Y1..Y4`` (9 columns).
        """
        return CoupleData._simulate(size, seed, totensor)

    @staticmethod
    def generate_effect_exm1(a, b, c, *, n_samples: int = 10000, seed=None):
        """Average the noise-free Y2 equation over a grid of fixed A2 values.

        A2 is set to each of ``c`` evenly spaced values in ``[a, b]``; U and
        A4 are sampled ``n_samples`` times and the Y2 equation is averaged.

        Args:
            a: First grid value.
            b: Last grid value.
            c: Number of grid points.
            n_samples: Number of Monte Carlo draws.
            seed: If not ``None``, seed the global NumPy RNG first so the
                result is reproducible; ``None`` leaves the RNG untouched.

        Returns:
            Tuple ``(grid, averages)`` of two arrays of length ``c``.
        """
        cls = CoupleData
        grid = np.linspace(a, b, c)
        cls._seed_if_given(seed)
        u = np.random.uniform(-1, 1, size=n_samples)
        a4 = cls._mean_a4(u) + cls._noise(n_samples)
        averages = np.array([np.mean(cls._mean_y2(a2, a4, u)) for a2 in grid])
        return grid, averages

    @staticmethod
    def generate_effect_exm2(a, b, c, *, n_samples: int = 10000, seed=None):
        """Average the noise-free Y1 equation with A1 and A3 fixed together.

        A1 and A3 are both set to the same grid value (``c`` evenly spaced
        values in ``[a, b]``); U, A2, A4 and A5 are sampled ``n_samples``
        times and the Y1 equation is averaged.

        Args:
            a: First grid value.
            b: Last grid value.
            c: Number of grid points.
            n_samples: Number of Monte Carlo draws.
            seed: If not ``None``, seed the global NumPy RNG first so the
                result is reproducible; ``None`` leaves the RNG untouched.

        Returns:
            Tuple ``(grid, averages)`` of two arrays of length ``c``.
        """
        cls = CoupleData
        grid = np.linspace(a, b, c)
        cls._seed_if_given(seed)
        u = np.random.uniform(-1, 1, size=n_samples)
        a2 = cls._mean_a2(u) + cls._noise(n_samples)
        a4 = cls._mean_a4(u) + cls._noise(n_samples)
        a5 = cls._mean_a5(u) + cls._noise(n_samples)
        averages = np.array(
            [np.mean(cls._mean_y1(g, a2, g, a4, a5, u)) for g in grid]
        )
        return grid, averages

    @staticmethod
    def generate_effect_exm4(a, b, c, *, n_samples: int = 10000, seed=None):
        """Average the noise-free Y4 equation with A1 and A5 fixed together.

        A1 and A5 are both set to the same grid value (``c`` evenly spaced
        values in ``[a, b]``); only U is sampled (``n_samples`` draws).

        Args:
            a: First grid value.
            b: Last grid value.
            c: Number of grid points.
            n_samples: Number of Monte Carlo draws.
            seed: If not ``None``, seed the global NumPy RNG first so the
                result is reproducible; ``None`` leaves the RNG untouched.

        Returns:
            Tuple ``(grid, averages)`` of two arrays of length ``c``.
        """
        cls = CoupleData
        grid = np.linspace(a, b, c)
        cls._seed_if_given(seed)
        u = np.random.uniform(-1, 1, size=n_samples)
        averages = np.array([np.mean(cls._mean_y4(g, g, u)) for g in grid])
        return grid, averages

    @staticmethod
    def generate_effect_exm5(a, b, c, *, n_samples: int = 10000, seed=None):
        """Average the noise-free Y1 equation over a grid of fixed A3 values.

        A3 is set to each of ``c`` evenly spaced values in ``[a, b]``; U, A1,
        A2, A4 and A5 are sampled ``n_samples`` times and the Y1 equation is
        averaged.

        Args:
            a: First grid value.
            b: Last grid value.
            c: Number of grid points.
            n_samples: Number of Monte Carlo draws.
            seed: If not ``None``, seed the global NumPy RNG first so the
                result is reproducible; ``None`` leaves the RNG untouched.

        Returns:
            Tuple ``(grid, averages)`` of two arrays of length ``c``.
        """
        cls = CoupleData
        cls._seed_if_given(seed)
        u = np.random.uniform(-1, 1, size=n_samples)
        a1 = cls._mean_a1(u) + cls._noise(n_samples)
        a2 = cls._mean_a2(u) + cls._noise(n_samples)
        grid = np.linspace(a, b, c)
        a4 = cls._mean_a4(u) + cls._noise(n_samples)
        a5 = cls._mean_a5(u) + cls._noise(n_samples)
        averages = np.array(
            [np.mean(cls._mean_y1(a1, a2, a3, a4, a5, u)) for a3 in grid]
        )
        return grid, averages

    @staticmethod
    def generate_effect_exm6(a, b, c, *, n_samples: int = 10000, seed=None):
        """Average the noise-free Y1 equation with A3 and A4 fixed together.

        A3 and A4 are both set to the same grid value (``c`` evenly spaced
        values in ``[a, b]``); U, A1, A2 and A5 are sampled ``n_samples``
        times and the Y1 equation is averaged.

        Args:
            a: First grid value.
            b: Last grid value.
            c: Number of grid points.
            n_samples: Number of Monte Carlo draws.
            seed: If not ``None``, seed the global NumPy RNG first so the
                result is reproducible; ``None`` leaves the RNG untouched.

        Returns:
            Tuple ``(grid, averages)`` of two arrays of length ``c``.
        """
        cls = CoupleData
        cls._seed_if_given(seed)
        u = np.random.uniform(-1, 1, size=n_samples)
        a1 = cls._mean_a1(u) + cls._noise(n_samples)
        a2 = cls._mean_a2(u) + cls._noise(n_samples)
        grid = np.linspace(a, b, c)
        a5 = cls._mean_a5(u) + cls._noise(n_samples)
        averages = np.array(
            [np.mean(cls._mean_y1(a1, a2, g, g, a5, u)) for g in grid]
        )
        return grid, averages



